# encoding:utf-8
import math
import base64
import logging
import datetime
from tornado.ioloop import IOLoop
from tornado.log import access_log
from tornado.options import options
from utils.signature import Signature
from tornado.gen import coroutine, Return
from baseHandler import BaseRequestHandler
from tornado.web import asynchronous, HTTPError
# This concurrency library (concurrent.futures) ships with Python 3; on Python 2 install the backport first: sudo pip install futures
from tornado.concurrent import run_on_executor
from utils.cephRESTApi import CephRestApi, Client
from tornado.websocket import websocket_connect, WebSocketClosedError
from model.pool import Pool
from sqlalchemy.orm.exc import NoResultFound
from middleware.authMiddleware import PermissionCheckMiddleware, CheckPoolMiddleware

# Single shared middleware instance; PoolHandler lists it in its `middleware`
# tuple. What it checks lives in middleware.authMiddleware (presumably a
# superuser permission check, going by its name -- confirm there).
check_superuser = PermissionCheckMiddleware()
# Prefer the stdlib json module and fall back to simplejson if it is missing.
try:
    import json
except ImportError:
    import simplejson as json

# Root logger (getLogger() with no name); used by RBDHandler.flatten and SyncHandler.
logger = logging.getLogger()


class XattrHandler(BaseRequestHandler):
    """HTTP front for the object xattr calls of ``self.ceph``.

    Every verb addresses one object through the ``poolname`` and ``object``
    request arguments and writes the backend's reply back unchanged.
    """

    def _target(self):
        """Return ``(poolname, object)`` from the request, in that order."""
        return self.get_argument('poolname'), self.get_argument('object')

    @coroutine
    def get(self):
        """Write the result of ``self.ceph.obj_get_xattrs`` for the object."""
        pool_name, obj_name = self._target()
        reply = yield self.ceph.obj_get_xattrs(pool_name, obj_name)
        self.write(reply)

    @coroutine
    def post(self):
        """Set xattr ``key`` to ``val`` on the object via ``obj_set_xattr``."""
        pool_name, obj_name = self._target()
        attr_key = self.get_argument('key')
        attr_val = self.get_argument('val')
        reply = yield self.ceph.obj_set_xattr(pool_name, obj_name, attr_key, attr_val)
        self.write(reply)

    @coroutine
    def delete(self):
        """Remove xattr ``key`` from the object via ``obj_del_xattr``."""
        pool_name, obj_name = self._target()
        attr_key = self.get_argument('key')
        reply = yield self.ceph.obj_del_xattr(pool_name, obj_name, attr_key)
        self.write(reply)


class PoolHandler(BaseRequestHandler):
    """Manage Ceph pools together with their ``Pool`` database rows.

    Cluster-side state goes through ``CephRestApi``; the ``mold`` (pool type),
    ``policy`` and ``html`` attributes are kept in the ``Pool`` table.
    Requests pass through the ``check_superuser`` middleware.
    """
    middleware = (check_superuser,)

    @asynchronous
    @coroutine
    def get(self):
        """Describe one pool or list all pools.

        Without ``poolname`` the result of ``osd_lspools`` is written. With
        ``poolname`` and no ``var`` the result of ``get_pool(poolname)`` is
        written. Otherwise ``var`` selects:

        * ``quota`` -- ``CephRestApi.get_pool_quota``;
        * ``stat`` -- ``self.ceph.pool_info``;
        * ``policy`` / ``type`` / ``html`` -- the stored ``Pool`` row as a
          dict, or an ERROR payload when no row exists;
        * anything else -- ``CephRestApi.get_pool(poolname, var)``.
        """
        poolname = self.get_argument('poolname', None)
        ceph = CephRestApi()
        if not poolname:
            res = yield ceph.osd_lspools()
        else:
            var = self.get_argument('var', None)
            if not var:
                res = yield ceph.get_pool(poolname)
            elif var == u'quota':
                res = yield ceph.get_pool_quota(poolname)
            elif var == u'stat':
                res = yield self.ceph.pool_info(poolname)
            elif var in (u'policy', u'type', u'html',):

                def _lambda(db):
                    try:
                        pool = db.query(Pool).filter_by(name=poolname).one()
                    except NoResultFound:
                        return {'status': 'ERROR', 'output': u'not find pool'}
                    return {'status': 'OK', 'output': pool.to_dict()}

                res = yield self.asyncdb(_lambda)
            else:
                res = yield ceph.get_pool(poolname, var)
        self.write(res)

    @asynchronous
    @coroutine
    def post(self):
        """Create a pool of the requested ``type`` and record it in the database.

        ``type`` must be ``normal``, ``performance`` or ``high-performance``.
        The matching ``<type>_ruleset`` / ``<type>_size`` tornado options
        supply the ruleset and the ``size`` pool setting. The optional
        ``size`` argument is applied as the pool's ``max_bytes``. The ``Pool``
        row is inserted only when ``create_pool`` reported ``OK`` and no row
        exists yet.

        Raises:
            HTTPError: 403 for an unknown ``type``.
            ValueError: when ``policy`` or ``html`` is not an integer.
        """
        poolname = self.get_argument('poolname')
        mode = self.get_argument('type')
        policy = self.get_argument('policy', 2)
        html = self.get_argument('html', 1)
        size = self.get_argument('size', None)
        ceph = CephRestApi()
        if mode not in ('normal', 'performance', 'high-performance',):
            raise HTTPError(403, 'check type error ')
        # Convert the DB attributes before touching the cluster, so a
        # malformed value cannot leave a pool behind without its Pool row.
        # int(policy) also matches patch(), which stores policy as an int.
        policy = int(policy)
        html = bool(int(html))
        ruleset = getattr(options, mode + '_ruleset')
        res = yield ceph.create_pool(poolname=poolname, ruleset=ruleset)
        if res.get('status') in ('OK',):
            yield ceph.set_pool(poolname, 'size', getattr(options, mode + '_size'))
            yield ceph.set_pool(poolname, 'crush_ruleset', ruleset)
            if size:
                yield ceph.set_size_or_object_num_for_pool(poolname, 'max_bytes', size)

            def _lambda(db):
                # Insert the metadata row only if one does not exist already.
                try:
                    db.query(Pool).filter_by(name=poolname).one()
                except NoResultFound:
                    db.add(Pool(poolname, mold=mode, policy=policy, html=html))
                    db.commit()

            yield self.asyncdb(_lambda)
        self.write(res)

    @asynchronous
    @coroutine
    def delete(self):
        """Delete the pool from the cluster and drop its ``Pool`` row, if any.

        NOTE(review): the row is removed even when ``delete_pool`` did not
        report success -- confirm this is intended.
        """
        poolname = self.get_argument('poolname')
        ceph = CephRestApi()
        res = yield ceph.delete_pool(poolname)

        def _lambda(db):
            try:
                pool = db.query(Pool).filter_by(name=poolname).one()
            except NoResultFound:
                return
            db.delete(pool)
            db.commit()

        yield self.asyncdb(_lambda)
        self.write(res)

    @asynchronous
    @coroutine
    def put(self):
        """Update one pool setting.

        * ``max_objects`` / ``max_bytes`` -- ``set_size_or_object_num_for_pool``;
        * ``policy`` / ``html`` -- stored as ``int(val)`` on the ``Pool`` row
          (silently skipped when no row exists; always answers OK);
        * anything else -- ``CephRestApi.set_pool(poolname, var, val)``.
        """
        poolname = self.get_argument('poolname')
        var = self.get_argument('var')
        val = self.get_argument('val')
        methods = ('max_objects', 'max_bytes',)
        ceph = CephRestApi()
        policy_and_html = (u'policy', u'html')
        if var in methods:
            res = yield ceph.set_size_or_object_num_for_pool(poolname, var, val)
            self.write(res)
        elif var in policy_and_html:
            def _lambda(db):
                try:
                    pool = db.query(Pool).filter_by(name=poolname).one()
                except NoResultFound:
                    return
                setattr(pool, var, int(val))
                db.add(pool)
                # Persist the change; post() and delete() commit explicitly too,
                # and without this the update was never committed here.
                db.commit()

            yield self.asyncdb(_lambda)
            self.write({"status": 'OK', 'output': ''})
        else:
            res = yield ceph.set_pool(poolname, var, val)
            self.write(res)

    @asynchronous
    @coroutine
    def patch(self):
        """Register metadata for a pool without touching the cluster.

        Inserts a ``Pool`` row (``policy`` and ``html`` cast to int) when none
        exists for ``poolname``; always answers OK.
        """
        poolname = self.get_argument('poolname')
        policy = self.get_argument('policy')
        mold = self.get_argument('mold')
        html = self.get_argument('html')

        def _lambda(db):
            try:
                db.query(Pool).filter_by(name=poolname).one()
            except NoResultFound:
                pool = Pool(name=poolname, policy=int(policy), mold=mold, html=int(html))
                db.add(pool)
                # Commit like post()/delete(); the insert was never committed before.
                db.commit()

        yield self.asyncdb(_lambda)
        self.write({"status": 'OK', 'output': ''})


class AuthHandler(BaseRequestHandler):
    """Manage ``client.<username>`` auth entities through ``CephRestApi``."""

    @asynchronous
    @coroutine
    def get(self):
        """Write one client's auth entry when ``username`` is given, else the full list."""
        api = CephRestApi()
        user = self.get_argument('username', None)
        if not user:
            listing = yield api.auth_list()
            self.write(listing)
            return
        entry = yield api.auth_get(entity='client.{0}'.format(user))
        self.write(entry)

    @asynchronous
    @coroutine
    def delete(self):
        """Delete the ``client.<username>`` entity."""
        api = CephRestApi()
        entity = 'client.{0}'.format(self.get_argument('username'))
        outcome = yield api.auth_del(entity=entity)
        self.write(outcome)

    @asynchronous
    @coroutine
    def post(self):
        """Create ``client.<username>`` with caps scoped to two pools.

        The client gets ``allow r`` on the monitors, ``rwx`` on ``poolname``
        and ``rx`` on the source pool, which is ``sourcepoolname`` when
        supplied and ``options.ceph_source_pool`` otherwise.
        """
        api = CephRestApi()
        user = self.get_argument('username')
        target_pool = self.get_argument('poolname')
        source_pool = self.get_argument('sourcepoolname', None) or options.ceph_source_pool
        osd_caps = ('allow class-read object_prefix rbd_children, allow rwx pool={0}, allow rx pool={1}'
                    .format(target_pool, source_pool))
        outcome = yield api.auth_add(entity='client.{0}'.format(user),
                                     caps={'mon': 'allow r', 'osd': osd_caps})
        self.write(outcome)


class RBDSnapHandler(BaseRequestHandler):
    """Snapshot management for an RBD image (``blockname``) in ``poolname``.

    ``put`` schedules ``self.rollback`` on the IOLoop and replies at once; the
    rollback outcome is reported later to the optional ``callback`` URL.
    """
    middleware = (CheckPoolMiddleware(),)

    @coroutine
    def get(self):
        """Write the result of ``self.ceph.list_rbd_snap`` for the image."""
        pool_name = self.get_argument('poolname')
        image = self.get_argument('blockname')
        snaps = yield self.ceph.list_rbd_snap(pool_name, image)
        self.write(snaps)

    @coroutine
    def post(self):
        """Create snapshot ``snapname`` of the image."""
        pool_name = self.get_argument('poolname')
        image = self.get_argument('blockname')
        snap = self.get_argument('snapname')
        created = yield self.ceph.create_rbd_snap(pool_name, image, snap)
        self.write(created)

    @coroutine
    def delete(self):
        """Delete snapshot ``snapname`` of the image."""
        pool_name = self.get_argument('poolname')
        image = self.get_argument('blockname')
        snap = self.get_argument('snapname')
        removed = yield self.ceph.delete_rbd_snap(pool_name, image, snap)
        self.write(removed)

    def put(self):
        """Schedule a rollback of the image to ``snapname`` and answer OK immediately.

        ``callback_params`` is split on commas and joined with ``&`` into the
        callback query string (presumably already-formatted ``k=v`` pieces --
        confirm with callers).
        """
        pool_name = self.get_argument('poolname')
        image = self.get_argument('blockname')
        snap = self.get_argument('snapname')
        extra = self.get_argument('callback_params', None)
        callback = self.get_argument('callback', None)
        IOLoop.instance().add_callback(self.rollback,
                                       poolname=pool_name, blockname=image, snapname=snap,
                                       params=extra, callback_url=callback)
        self.write({'status': 'OK', 'output': ''})

    @run_on_executor
    @coroutine
    def rollback(self, poolname, blockname, snapname, params, callback_url):
        """Roll the image back to ``snapname``, log timings, then GET the callback.

        The callback is only called when ``callback_url`` contains ``http``.
        NOTE(review): status/output are interpolated into the query string
        without URL-encoding.
        """
        began = datetime.datetime.now()
        result = yield self.ceph.rollback_rbd_to_snap(poolname, blockname, snapname)
        elapsed = (datetime.datetime.now() - began).seconds
        access_log.info('rollback It takes %i seconds' % (elapsed,))
        query_parts = params.split(',') if params else []
        if not (callback_url and u'http' in callback_url):
            return
        url = '{0}?{1}&status={2}&output={3}'.format(callback_url, '&'.join(query_parts),
                                                     result.get('status'), result.get('output'))
        client = Client(root_path='', response_json=False)
        began = datetime.datetime.now()
        yield client.get(url)
        elapsed = (datetime.datetime.now() - began).seconds
        access_log.info(msg='rollback {0} and callback address is {1} timespan {2} seconds'.format(
            result.get('status'), url, elapsed))


class RBDProtectHandler(BaseRequestHandler):
    """Query, set and clear protection on an RBD snapshot.

    Every verb identifies the snapshot by ``poolname``, ``blockname`` and
    ``snapname`` and writes the backend reply unchanged.
    """
    middleware = (CheckPoolMiddleware(),)

    def _snapshot_args(self):
        """Return ``(poolname, blockname, snapname)`` read in that order."""
        pool_name = self.get_argument('poolname')
        image = self.get_argument('blockname')
        snap = self.get_argument('snapname')
        return pool_name, image, snap

    @coroutine
    def get(self):
        """Write the result of ``self.ceph.rbd_is_protect_snap``."""
        reply = yield self.ceph.rbd_is_protect_snap(*self._snapshot_args())
        self.write(reply)

    @coroutine
    def post(self):
        """Protect the snapshot via ``self.ceph.protect_rbd_snap``."""
        reply = yield self.ceph.protect_rbd_snap(*self._snapshot_args())
        self.write(reply)

    @coroutine
    def delete(self):
        """Unprotect the snapshot via ``self.ceph.unprotect_rbd_snap``."""
        reply = yield self.ceph.unprotect_rbd_snap(*self._snapshot_args())
        self.write(reply)


class RBDManageHandler(BaseRequestHandler):
    """Resize an existing RBD image."""
    middleware = (CheckPoolMiddleware(),)

    @coroutine
    def post(self):
        """Resize ``blockname`` in ``poolname`` to ``int(size)`` via ``self.ceph.resize_rbd``."""
        pool_name = self.get_argument('poolname')
        image = self.get_argument('blockname')
        new_size = int(self.get_argument('size'))
        outcome = yield self.ceph.resize_rbd(pool_name, image, new_size)
        self.write(outcome)


class RBDHandler(BaseRequestHandler):
    """Create, inspect, clone, delete and flatten RBD images in ``poolname``.

    ``patch`` schedules ``self.flatten`` on the IOLoop and replies at once; the
    outcome is reported later to the optional ``callback`` URL.
    """
    middleware = (CheckPoolMiddleware(),)

    @coroutine
    def get(self):
        """Write ``rbd_info`` for ``blockname`` when given, else ``list_rbd`` for the pool."""
        pool_name = self.get_argument('poolname')
        image = self.get_argument('blockname', None)
        if not image:
            images = yield self.ceph.list_rbd(pool_name)
            self.write(images)
            return
        details = yield self.ceph.rbd_info(pool_name, image)
        self.write(details)

    @coroutine
    def post(self):
        """Create image ``blockname`` with ``int(size)``."""
        pool_name = self.get_argument('poolname')
        image = self.get_argument('blockname')
        image_size = self.get_argument('size')
        created = yield self.ceph.create_rbd(pool_name, image, int(image_size))
        self.write(created)

    @coroutine
    def put(self):
        """Clone ``blockname@snapname`` into ``destpool``/``destblock``."""
        pool_name = self.get_argument('poolname')
        image = self.get_argument('blockname')
        snap = self.get_argument('snapname')
        target_pool = self.get_argument('destpool')
        target_image = self.get_argument('destblock')
        cloned = yield self.ceph.clone_rbd(pool_name, image, snap, target_pool, target_image)
        self.write(cloned)

    @coroutine
    def delete(self):
        """Delete image ``blockname``."""
        pool_name = self.get_argument('poolname')
        image = self.get_argument('blockname')
        removed = yield self.ceph.delete_rbd(pool_name, image)
        self.write(removed)

    def patch(self):
        """Schedule a flatten of ``blockname`` and answer OK immediately.

        ``callback_params`` is split on commas and joined with ``&`` into the
        callback query string (presumably ``k=v`` pieces -- confirm with callers).
        """
        pool_name = self.get_argument('poolname')
        image = self.get_argument('blockname')
        callback = self.get_argument('callback', None)
        extra = self.get_argument('callback_params', None)
        IOLoop.instance().add_callback(self.flatten,
                                       poolname=pool_name, blockname=image,
                                       callback_url=callback, params=extra)
        self.write({'status': 'OK', 'output': ''})

    @run_on_executor
    @coroutine
    def flatten(self, poolname, blockname, callback_url, params):
        """Flatten the image, log timings, then GET the callback if it contains ``http``.

        NOTE(review): unlike RBDSnapHandler.rollback this wraps ``output`` in
        double quotes and builds ``Client`` with only a root path; the query
        values are not URL-encoded.
        """
        began = datetime.datetime.now()
        result = yield self.ceph.flatten_rbd(poolname, blockname)
        elapsed = (datetime.datetime.now() - began).seconds
        access_log.info('flatten It takes %i seconds' % (elapsed,))
        query_parts = params.split(',') if params else []
        if not (callback_url and u'http' in callback_url):
            return
        url = '{0}?{1}&status={2}&output="{3}"'.format(callback_url, '&'.join(query_parts),
                                                       result.get('status'), result.get('output'))
        client = Client(str())
        began = datetime.datetime.now()
        yield client.get(url)
        elapsed = (datetime.datetime.now() - began).seconds
        logger.info('flatten {0} and callback address is {1}, timespan {2} seconds'.format(
            result.get('status'), url, elapsed))


class SyncHandler(BaseRequestHandler):
    """Copy a RADOS object to a remote endpoint over a websocket, chunk by chunk.

    Progress lives in the cache under ``"<src_pool>-<src_obj>-<dst_pool>-<dst_obj>"``
    as the string ``"<block_num>-<block_size>-<current_num>-<offset>-<size>-<seconds>"``
    with a 60 second expiry that is refreshed after every acknowledged chunk.
    ``get`` reports that progress. ``post`` starts a copy only when no progress
    is cached for the object pair.
    """
    middleware = (CheckPoolMiddleware('src_poolname'),)

    @coroutine
    def get(self):
        """Report progress of a running copy, or an ERROR payload if none is cached."""
        src_poolname = self.get_argument('src_poolname')
        src_object = self.get_argument('src_object')
        dst_poolname = self.get_argument('dst_poolname')
        dst_object = self.get_argument('dst_object')
        cache_key = u'{0}-{1}-{2}-{3}'.format(src_poolname, src_object, dst_poolname, dst_object)
        self.before_process = yield self.cache('GET', cache_key)
        if self.before_process:
            block_num, block_size, current_num, offset, size, seconds = yield self.get_object_copy_process(
                src_poolname, src_object)
            self.write({'status': "OK", "output": self.get_schedule(seconds, block_num, block_size, current_num)})
        else:
            self.write({'status': "ERROR", "output": 'not find process'})

    @coroutine
    def post(self):
        """Start a copy if none is cached for this pair, then write its progress.

        On a fresh start the initial state is cached and ``background_process``
        is scheduled on the IOLoop to stream the object to ``dst_url``.
        """
        src_poolname = self.get_argument('src_poolname')
        src_object = self.get_argument('src_object')
        dst_poolname = self.get_argument('dst_poolname')
        dst_object = self.get_argument('dst_object')
        dst_url = self.get_argument('dst_url')
        self.cache_key = u'{0}-{1}-{2}-{3}'.format(src_poolname, src_object, dst_poolname, dst_object)
        self.before_process = yield self.cache('GET', self.cache_key)
        block_num, block_size, current_num, offset, size, seconds = yield self.get_object_copy_process(
            src_poolname, src_object)
        if not self.before_process:
            yield self.set_copy_process_in_cache(block_num, block_size, current_num, offset, seconds, size)

            IOLoop.instance().add_callback(self.background_process, **{'dst_url': dst_url, 'dst_poolname': dst_poolname,
                                                                       'dst_object': dst_object,
                                                                       'src_poolname': src_poolname,
                                                                       'src_object': src_object})

        self.write(
            {'status': "OK", "output": self.get_schedule(seconds, block_num, block_size, current_num)})

    @run_on_executor
    def background_process(self, **kwargs):
        """Run ``copy(**kwargs)`` off the request path (the returned future is not awaited)."""
        self.copy(**kwargs)

    @coroutine
    def copy(self, dst_url, dst_poolname, dst_object, src_poolname, src_object):
        """Stream ``src_object`` to the websocket endpoint behind ``dst_url``.

        The exchange implemented here has three steps:

        1. Send a header with filename, poolname, size and offset.
        2. Wait for a ``'success'`` reply.
        3. Send base64-encoded chunks one at a time, reading an ack after each.

        Progress is cached after every ``'success'`` ack. If the websocket
        closes mid-copy, ``background_process`` is rescheduled with the same
        arguments. The cache entry is deleted once every chunk has been sent,
        and the connection is closed on every exit path.
        """
        url = self.check_url(dst_url)
        logger.info(url)
        conn = yield websocket_connect(url)
        try:
            # Reload cached progress so a rescheduled run continues from the
            # last acknowledged chunk rather than the stale value post() read.
            self.before_process = yield self.cache('GET', self.cache_key)
            block_num, block_size, current_num, offset, size, _ = yield self.get_object_copy_process(
                src_poolname, src_object)
            # NOTE(review): unicode(dict) sends the Python repr of the dict, not
            # JSON; kept as-is because the receiving side is not visible here.
            yield conn.write_message(
                unicode({u"filename": dst_object, u"poolname": dst_poolname,
                         u"size": size,
                         u"offset": offset}))
            res = yield conn.read_message()
            if res not in ('success',):
                return
            try:
                while current_num < block_num:
                    chunk_start = datetime.datetime.now()
                    data = yield self.ceph.async_read_full(src_poolname, src_object, block_size, self._id, offset)
                    yield conn.write_message(base64.b64encode(data))
                    res = yield conn.read_message()
                    if res in ('success',):
                        # now - start: the old reversed subtraction produced a
                        # negative timedelta whose .seconds is close to 86400.
                        seconds = (datetime.datetime.now() - chunk_start).seconds
                        yield self.set_copy_process_in_cache(block_num, block_size, current_num, offset,
                                                             seconds, size)
                        current_num += 1
                        offset += block_size
                        if offset + block_size > size:
                            # Shrink the final chunk; once nothing is left, push
                            # current_num one past so the loop terminates.
                            block_size = size - offset
                            if block_size <= 0:
                                current_num += 1
            except WebSocketClosedError as e:
                logger.info(e)
                if current_num < block_num:
                    # Retry with the caller's original dst_url: passing the
                    # already-signed url made check_url append a second
                    # "ws/?<signature>" suffix.
                    IOLoop.instance().add_callback(self.background_process,
                                                   **{'dst_url': dst_url, 'dst_poolname': dst_poolname,
                                                      'dst_object': dst_object, 'src_poolname': src_poolname,
                                                      'src_object': src_object})
            finally:
                if current_num >= block_num:
                    yield self.cache('del', self.cache_key)
        finally:
            conn.close()

    @coroutine
    def set_copy_process_in_cache(self, block_num, block_size, current_num, offset, seconds, size):
        """Store progress under ``self.cache_key`` with a 60 second expiry.

        Note the stored field order is block_num, block_size, current_num,
        offset, size, seconds (size before seconds, unlike the parameters).
        """
        yield self.cache('set', self.cache_key,
                         u'{0}-{1}-{2}-{3}-{4}-{5}'.format(block_num, block_size, current_num, offset,
                                                           size, seconds))
        yield self.cache("EXPIRE", self.cache_key, 60)

    def check_url(self, dst_url):
        """Normalise ``dst_url`` into a signed websocket URL ``<base>/ws/?<signature>``.

        A trailing ``/`` is added if missing, and ``ws://`` is prepended unless
        the URL already starts with ``ws:`` or ``wss``. An empty string no
        longer raises IndexError.
        """
        if not dst_url.endswith(u'/'):
            dst_url = dst_url + u'/'
        if not dst_url.startswith((u'ws:', u'wss')):
            dst_url = u'ws://' + dst_url
        signature = Signature(secret_key=options.secret_key, access_key=options.access_key).get_url()
        return u'{0}ws/?{1}'.format(dst_url, signature)

    @coroutine
    def get_object_copy_process(self, src_poolname, src_obj):
        """Return ``(block_num, block_size, current_num, offset, size, seconds)``.

        Parsed from ``self.before_process`` (the cached progress string) when
        set. Otherwise it is computed from the size reported by
        ``self.ceph.object_info`` and ``options.ceph_block_size``: a fresh
        state whose ``block_num`` is the chunk count plus one
        (``get_schedule`` divides by ``block_num - 1``). Returns all zeros when
        ``object_info`` does not report OK.
        """
        if self.before_process:
            raise Return(tuple(int(part) for part in self.before_process.split('-')))
        file_info = yield self.ceph.object_info(src_poolname, src_obj)
        block_size = options.ceph_block_size
        if file_info.get('status', None) not in ('OK',):
            raise Return((0, 0, 0, 0, 0, 0,))
        size = int(file_info.get('output')[0])
        if size > block_size:
            # float() so Python 2 does not floor the division before ceil().
            block_num = int(math.ceil(float(size) / block_size))
        else:
            block_num = 1
            block_size = size
        raise Return((block_num + 1, block_size, 0, 0, size, 0))

    def get_schedule(self, before_time, block_num, block_size, current_num):
        """Build the progress payload ``{'schedule': 'NN.NN%', 'speed': '<x>kb/s'}``.

        ``schedule`` is ``current_num / (block_num - 1)`` as a percentage and 0
        when ``block_num <= 1`` (previously ``block_num == 1`` divided by zero).
        ``speed`` is ``block_size / before_time * 1000`` and 0 when
        ``before_time`` is 0.
        NOTE(review): ``before_time`` is whole seconds, so the ``* 1000``
        scaling versus the ``kb/s`` label looks inconsistent -- confirm units.
        """
        before_time = int(before_time)
        if before_time == 0:
            speed = 0
        else:
            speed = round(float(block_size) / float(before_time), 4) * 1000
        if block_num > 1:
            schedule = round(float(current_num) / float(block_num - 1), 4) * 100
        else:
            schedule = 0
        return {'schedule': '%.2f%%' % (schedule),
                'speed': '{0}kb/s'.format(speed)}
